package math

import (
	"encoding"
	"encoding/json"
	"errors"
	"fmt"
	"math/big"
	"math/bits"
	"strings"
	"sync"
	"testing"
)

// MaxBitLen defines the maximum bit length supported bit Int and Uint types.
const MaxBitLen = 256

// maxWordLen defines the maximum word length supported by Int and Uint types.
// We check overflow, by first doing a fast check if the word length is below maxWordLen
// and if not then do the slower full bitlen check.
// NOTE: If MaxBitLen is not a multiple of bits.UintSize, then we need to edit the used logic slightly.
const maxWordLen = MaxBitLen / bits.UintSize

// Integer errors
var (
	// ErrIntOverflow is the error returned when an integer overflow occurs
	ErrIntOverflow = errors.New("integer overflow")
	// ErrDivideByZero is the error returned when a divide by zero occurs
	ErrDivideByZero = errors.New("divide by zero")
)

func newIntegerFromString(s string) (*big.Int, bool) {
	return new(big.Int).SetString(s, 0)
}

func equal(i, i2 *big.Int) bool { return i.Cmp(i2) == 0 }

func gt(i, i2 *big.Int) bool { return i.Cmp(i2) == 1 }

func gte(i, i2 *big.Int) bool { return i.Cmp(i2) >= 0 }

func lt(i, i2 *big.Int) bool { return i.Cmp(i2) == -1 }

func lte(i, i2 *big.Int) bool { return i.Cmp(i2) <= 0 }

func add(i, i2 *big.Int) *big.Int { return new(big.Int).Add(i, i2) }

func sub(i, i2 *big.Int) *big.Int { return new(big.Int).Sub(i, i2) }

func mul(i, i2 *big.Int) *big.Int { return new(big.Int).Mul(i, i2) }

func div(i, i2 *big.Int) *big.Int { return new(big.Int).Quo(i, i2) }

func mod(i, i2 *big.Int) *big.Int { return new(big.Int).Mod(i, i2) }

func neg(i *big.Int) *big.Int { return new(big.Int).Neg(i) }

func abs(i *big.Int) *big.Int { return new(big.Int).Abs(i) }

func min(i, i2 *big.Int) *big.Int {
	if i.Cmp(i2) == 1 {
		return new(big.Int).Set(i2)
	}

	return new(big.Int).Set(i)
}

func max(i, i2 *big.Int) *big.Int {
	if i.Cmp(i2) == -1 {
		return new(big.Int).Set(i2)
	}

	return new(big.Int).Set(i)
}

func unmarshalText(i *big.Int, text string) error {
	if err := i.UnmarshalText([]byte(text)); err != nil {
		return err
	}

	if bigIntOverflows(i) {
		return fmt.Errorf("integer out of range: %s", text)
	}

	return nil
}

var _ customProtobufType = (*Int)(nil)

// Int wraps big.Int with a 256 bit range bound
// Checks overflow, underflow and division by zero
// Exists in range from -(2^256 - 1) to 2^256 - 1
type Int struct {
	i *big.Int
}

// BigInt converts Int to big.Int
func (i Int) BigInt() *big.Int {
	if i.IsNil() {
		return nil
	}
	return new(big.Int).Set(i.i)
}

// BigIntMut converts Int to big.Int, mutative the input
func (i Int) BigIntMut() *big.Int {
	if i.IsNil() {
		return nil
	}
	return i.i
}

// IsNil returns true if Int is uninitialized
func (i Int) IsNil() bool {
	return i.i == nil
}

// NewInt constructs Int from int64
func NewInt(n int64) Int {
	return Int{big.NewInt(n)}
}

// NewIntFromUint64 constructs an Int from a uint64.
func NewIntFromUint64(n uint64) Int {
	b := big.NewInt(0)
	b.SetUint64(n)
	return Int{b}
}

// NewIntFromBigInt constructs Int from big.Int. If the provided big.Int is nil,
// it returns an empty instance. This function panics if the bit length is > 256.
// Note, the caller can safely mutate the argument after this function returns.
func NewIntFromBigInt(i *big.Int) Int {
	if i == nil {
		return Int{}
	}

	if bigIntOverflows(i) {
		panic("NewIntFromBigInt() out of bound")
	}

	return Int{new(big.Int).Set(i)}
}

// NewIntFromBigIntMut constructs Int from big.Int. If the provided big.Int is nil,
// it returns an empty instance. This function panics if the bit length is > 256.
// Note, this function mutate the argument.
func NewIntFromBigIntMut(i *big.Int) Int {
	if i == nil {
		return Int{}
	}

	if bigIntOverflows(i) {
		panic("NewIntFromBigInt() out of bound")
	}

	return Int{i}
}

// NewIntFromString constructs Int from string
func NewIntFromString(s string) (res Int, ok bool) {
	i, ok := newIntegerFromString(s)
	if !ok {
		return
	}
	// Check overflow
	if bigIntOverflows(i) {
		ok = false
		return
	}
	return Int{i}, true
}

// NewIntWithDecimal constructs Int with decimal
// Result value is n*10^dec
func NewIntWithDecimal(n int64, dec int) Int {
	if dec < 0 {
		panic("NewIntWithDecimal() decimal is negative")
	}
	exp := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(dec)), nil)
	i := new(big.Int)
	i.Mul(big.NewInt(n), exp)

	// Check overflow
	if bigIntOverflows(i) {
		panic("NewIntWithDecimal() out of bound")
	}
	return Int{i}
}

// ZeroInt returns Int value with zero
func ZeroInt() Int { return Int{big.NewInt(0)} }

// OneInt returns Int value with one
func OneInt() Int { return Int{big.NewInt(1)} }

// ToLegacyDec converts Int to LegacyDec
func (i Int) ToLegacyDec() LegacyDec {
	return LegacyNewDecFromInt(i)
}

// Int64 converts Int to int64
// Panics if the value is out of range
func (i Int) Int64() int64 {
	if !i.i.IsInt64() {
		panic("Int64() out of bound")
	}
	return i.i.Int64()
}

// IsInt64 returns true if Int64() not panics
func (i Int) IsInt64() bool {
	return i.i.IsInt64()
}

// Uint64 converts Int to uint64
// Panics if the value is out of range
func (i Int) Uint64() uint64 {
	if !i.i.IsUint64() {
		panic("Uint64() out of bounds")
	}
	return i.i.Uint64()
}

// IsUint64 returns true if Uint64() not panics
func (i Int) IsUint64() bool {
	return i.i.IsUint64()
}

// IsZero returns true if Int is zero
func (i Int) IsZero() bool {
	return i.i.Sign() == 0
}

// IsNegative returns true if Int is negative
func (i Int) IsNegative() bool {
	return i.i.Sign() == -1
}

// IsPositive returns true if Int is positive
func (i Int) IsPositive() bool {
	return i.i.Sign() == 1
}

// Sign returns sign of Int
func (i Int) Sign() int {
	return i.i.Sign()
}

// Equal compares two Ints
func (i Int) Equal(i2 Int) bool {
	return equal(i.i, i2.i)
}

// GT returns true if first Int is greater than second
func (i Int) GT(i2 Int) bool {
	return gt(i.i, i2.i)
}

// GTE returns true if receiver Int is greater than or equal to the parameter
// Int.
func (i Int) GTE(i2 Int) bool {
	return gte(i.i, i2.i)
}

// LT returns true if first Int is lesser than second
func (i Int) LT(i2 Int) bool {
	return lt(i.i, i2.i)
}

// LTE returns true if first Int is less than or equal to second
func (i Int) LTE(i2 Int) bool {
	return lte(i.i, i2.i)
}

// Add adds Int from another
func (i Int) Add(i2 Int) (res Int) {
	// Check overflow
	x, err := i.SafeAdd(i2)
	if err != nil {
		panic(err)
	}
	return x
}

// AddRaw adds int64 to Int
func (i Int) AddRaw(i2 int64) Int {
	return i.Add(NewInt(i2))
}

// SafeAdd adds Int from another and returns an error if overflow
func (i Int) SafeAdd(i2 Int) (res Int, err error) {
	res = Int{add(i.i, i2.i)}
	// Check overflow
	if bigIntOverflows(res.i) {
		return Int{}, ErrIntOverflow
	}
	return res, nil
}

// Sub subtracts Int from another
func (i Int) Sub(i2 Int) (res Int) {
	// Check overflow
	x, err := i.SafeSub(i2)
	if err != nil {
		panic(err)
	}
	return x
}

// SubRaw subtracts int64 from Int
func (i Int) SubRaw(i2 int64) Int {
	return i.Sub(NewInt(i2))
}

// SafeSub subtracts Int from another and returns an error if overflow or underflow
func (i Int) SafeSub(i2 Int) (res Int, err error) {
	res = Int{sub(i.i, i2.i)}
	// Check overflow/underflow
	if bigIntOverflows(res.i) {
		return Int{}, ErrIntOverflow
	}
	return res, nil
}

// Mul multiples two Ints
func (i Int) Mul(i2 Int) (res Int) {
	// Check overflow
	x, err := i.SafeMul(i2)
	if err != nil {
		panic(err)
	}
	return x
}

// MulRaw multiplies Int and int64
func (i Int) MulRaw(i2 int64) Int {
	return i.Mul(NewInt(i2))
}

// SafeMul multiples Int from another and returns an error if overflow
func (i Int) SafeMul(i2 Int) (res Int, err error) {
	res = Int{mul(i.i, i2.i)}
	// Check overflow
	if bigIntOverflows(res.i) {
		return Int{}, ErrIntOverflow
	}
	return res, nil
}

// Quo divides Int with Int
func (i Int) Quo(i2 Int) (res Int) {
	// Check division-by-zero
	x, err := i.SafeQuo(i2)
	if err != nil {
		panic("Division by zero")
	}
	return x
}

// QuoRaw divides Int with int64
func (i Int) QuoRaw(i2 int64) Int {
	return i.Quo(NewInt(i2))
}

// SafeQuo divides Int with Int and returns an error if division by zero
func (i Int) SafeQuo(i2 Int) (res Int, err error) {
	// Check division-by-zero
	if i2.i.Sign() == 0 {
		return Int{}, ErrDivideByZero
	}
	return Int{div(i.i, i2.i)}, nil
}

// Mod returns remainder after dividing with Int
func (i Int) Mod(i2 Int) Int {
	x, err := i.SafeMod(i2)
	if err != nil {
		panic(err)
	}
	return x
}

// ModRaw returns remainder after dividing with int64
func (i Int) ModRaw(i2 int64) Int {
	return i.Mod(NewInt(i2))
}

// SafeMod returns remainder after dividing with Int and returns an error if division by zero
func (i Int) SafeMod(i2 Int) (res Int, err error) {
	if i2.Sign() == 0 {
		return Int{}, ErrDivideByZero
	}
	return Int{mod(i.i, i2.i)}, nil
}

// Neg negates Int
func (i Int) Neg() (res Int) {
	return Int{neg(i.i)}
}

// Abs returns the absolute value of Int.
func (i Int) Abs() Int {
	return Int{abs(i.i)}
}

// MinInt return the minimum of the ints
func MinInt(i1, i2 Int) Int {
	return Int{min(i1.BigInt(), i2.BigInt())}
}

// MaxInt returns the maximum between two integers.
func MaxInt(i, i2 Int) Int {
	return Int{max(i.BigInt(), i2.BigInt())}
}

// String returns human-readable string
func (i Int) String() string {
	return i.i.String()
}

// MarshalJSON defines custom encoding scheme
func (i Int) MarshalJSON() ([]byte, error) {
	if i.i == nil { // Necessary since default Uint initialization has i.i as nil
		i.i = new(big.Int)
	}
	return marshalJSON(i.i)
}

// UnmarshalJSON defines custom decoding scheme
func (i *Int) UnmarshalJSON(bz []byte) error {
	if i.i == nil { // Necessary since default Int initialization has i.i as nil
		i.i = new(big.Int)
	}
	return unmarshalJSON(i.i, bz)
}

// marshalJSON for custom encoding scheme
// Must be encoded as a string for JSON precision
func marshalJSON(i encoding.TextMarshaler) ([]byte, error) {
	text, err := i.MarshalText()
	if err != nil {
		return nil, err
	}

	return json.Marshal(string(text))
}

// unmarshalJSON for custom decoding scheme
// Must be encoded as a string for JSON precision
func unmarshalJSON(i *big.Int, bz []byte) error {
	var text string
	if err := json.Unmarshal(bz, &text); err != nil {
		return err
	}

	return unmarshalText(i, text)
}

// MarshalYAML returns the YAML representation.
func (i Int) MarshalYAML() (interface{}, error) {
	return i.String(), nil
}

// Marshal implements the gogo proto custom type interface.
func (i Int) Marshal() ([]byte, error) {
	if i.i == nil {
		i.i = new(big.Int)
	}
	return i.i.MarshalText()
}

// MarshalTo implements the gogo proto custom type interface.
func (i *Int) MarshalTo(data []byte) (n int, err error) {
	if i.i == nil {
		i.i = new(big.Int)
	}
	if i.i.BitLen() == 0 { // The value 0
		n = copy(data, []byte{0x30})
		return n, nil
	}

	bz, err := i.Marshal()
	if err != nil {
		return 0, err
	}

	n = copy(data, bz)
	return n, nil
}

// Unmarshal implements the gogo proto custom type interface.
func (i *Int) Unmarshal(data []byte) error {
	if len(data) == 0 {
		i = nil
		return nil
	}

	if i.i == nil {
		i.i = new(big.Int)
	}

	if err := i.i.UnmarshalText(data); err != nil {
		return err
	}

	if bigIntOverflows(i.i) {
		return fmt.Errorf("integer out of range; got: %d, max: %d", i.i.BitLen(), MaxBitLen)
	}

	return nil
}

// Size implements the gogo proto custom type interface.
func (i *Int) Size() int {
	bz, _ := i.Marshal()
	return len(bz)
}

// MarshalAmino Override Amino binary serialization by proxying to protobuf.
func (i Int) MarshalAmino() ([]byte, error)   { return i.Marshal() }
func (i *Int) UnmarshalAmino(bz []byte) error { return i.Unmarshal(bz) }

// IntEq intended to be used with require/assert:  require.True(IntEq(...))
func IntEq(t *testing.T, exp, got Int) (*testing.T, bool, string, string, string) {
	t.Helper()
	return t, exp.Equal(got), "expected:\t%v\ngot:\t\t%v", exp.String(), got.String()
}

func hasOnlyDigits(s string) bool {
	if s == "" {
		return false
	}
	for _, r := range s {
		if r < '0' || r > '9' {
			return false
		}
	}
	return true
}

const thousandSeparator string = "'"

var stringsBuilderPool = &sync.Pool{
	New: func() any { return new(strings.Builder) },
}

// FormatInt formats an integer (encoded as in protobuf) into a value-rendered
// string following ADR-050. This function operates with string manipulation
// (instead of manipulating the int or math.Int object).
func FormatInt(v string) (string, error) {
	if len(v) == 0 {
		return "", errors.New("cannot format empty string")
	}

	sign := ""
	if v[0] == '-' {
		sign = "-"
		v = v[1:]
	}
	if len(v) > 1 {
		v = strings.TrimLeft(v, "0")
	}

	// Ensure that the string contains only digits at this point.
	if !hasOnlyDigits(v) {
		return "", fmt.Errorf("expecting only digits 0-9, but got non-digits in %q", v)
	}

	// 1. Less than 4 digits don't need any formatting.
	if len(v) <= 3 {
		return sign + v, nil
	}

	sb := stringsBuilderPool.Get().(*strings.Builder)
	defer stringsBuilderPool.Put(sb)
	sb.Reset()
	sb.Grow(len(v) + len(v)/3) // Exactly v + numberOfThousandSeparatorsIn(v)

	// 2. If the length of v is not a multiple of 3 e.g. 1234 or 12345, to achieve 1'234 or 12'345,
	// we can simply slide to the first mod3 values of v that aren't the multiples of 3 then insert in
	// the thousands separator so in this case: write(12'); then the remaining v will be entirely multiple
	// of 3 hence v = 34*
	if mod3 := len(v) % 3; mod3 != 0 {
		sb.WriteString(v[:mod3])
		v = v[mod3:]
		sb.WriteString(thousandSeparator)
	}

	// 3. By this point v is entirely multiples of 3 hence we just insert the separator at every 3 digit.
	for i := 0; i < len(v); i += 3 {
		end := i + 3
		sb.WriteString(v[i:end])
		if end < len(v) {
			sb.WriteString(thousandSeparator)
		}
	}

	return sign + sb.String(), nil
}

// check if the big int overflows.
func bigIntOverflows(i *big.Int) bool {
	// overflow is defined as i.BitLen() > MaxBitLen
	// however this check can be expensive when doing many operations.
	// So we first check if the word length is greater than maxWordLen.
	// However the most significant word could be zero, hence we still do the bitlen check.
	if len(i.Bits()) > maxWordLen {
		return i.BitLen() > MaxBitLen
	}
	return false
}
